#!/usr/bin/env python
# -*- coding: utf-8 -*-

__author__ = 'James Iter'
__date__ = '15/12/16'
__contact__ = 'james.iter.cn@gmail.com'
__copyright__ = '(c) 2015 by James Iter.'

import logging
from logging.handlers import TimedRotatingFileHandler
import sys
import re

import jimit as ji
import mysql.connector
from mysql.connector import errorcode
import redis
from flask import Flask
import time

from models.rules import Rules
from state_code import *

# Flask application object; its ``config`` supplies every setting read below
# (DB_S, REDIS, LOG_FILE_BASE, DATABASE, ...).
app = Flask(__name__)
# Imported only after ``app`` exists.
# NOTE(review): presumably ``config`` populates ``app.config`` and/or imports
# ``app`` from here; keep this ordering.
from config import *

# Python 2 only: make utf-8 the implicit str <-> unicode conversion codec.
reload(sys)
sys.setdefaultencoding('utf8')

# Merge this project's own state branches into jimit's shared state-code table.
# NOTE(review): ``own_state_branch`` presumably comes from one of the star
# imports above (``state_code`` or ``config``).
ji.state_code.index_state['branch'] = dict(ji.state_code.index_state['branch'], **own_state_branch)


def init_logger():
    """Create and return the process logger, backed by a daily-rotating file.

    The log file path is ``<LOG_FILE_BASE>.<PROCESS_TITLE>`` (both read from
    ``app.config``), and the same path is used as the logger name. The level
    is DEBUG when ``app.config['DEBUG']`` is truthy, otherwise WARNING. The
    file rolls over once a day and 7 old files are kept.

    Calling this more than once returns the same logger without attaching a
    second file handler.

    :return: the configured ``logging.Logger``.
    """
    import os

    log_dir = os.path.dirname(app.config['LOG_FILE_BASE'])
    # dirname() is '' for a bare file name, and os.makedirs('') would raise.
    if log_dir and not os.path.isdir(log_dir):
        os.makedirs(log_dir, 0o755)

    log_file_path = '.'.join([app.config['LOG_FILE_BASE'], app.config['PROCESS_TITLE']])
    _logger = logging.getLogger(log_file_path)
    _logger.setLevel(logging.DEBUG if app.config['DEBUG'] else logging.WARNING)

    # logging.getLogger() returns the same object for the same name, so a
    # repeated call would otherwise stack handlers and duplicate every record.
    if not _logger.handlers:
        # when="D", interval=1: roll over every day. Other supported `when`
        # values: "S", "M", "H", "midnight", "W0"-"W6" (0 = Monday).
        fh = TimedRotatingFileHandler(log_file_path, when="D", interval=1, backupCount=7)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(lineno)s - %(message)s')
        fh.setFormatter(formatter)
        _logger.addHandler(fh)
    return _logger

# Compiled once at import time (cheaper than compiling on every use).
# Matches one or more backslashes followed by a double quote. The raw string
# yields exactly the same pattern as the old '\\\+"' literal, without relying
# on the invalid escape sequence '\+'.
regex_sql_str = re.compile(r'\\+"')

# Module-level logger shared by everything below.
logger = init_logger()
# 'host:port:database' -> MySQLConnectionPool for non-cold domains; lets several
# groups/domains pointing at the same database share one pool.
origin_db_conn_map = dict()
# Same key scheme, for the per-month cold-storage databases.
cold_origin_db_conn_map = dict()
# Cold read-group pools flattened to
# {'sequence': sequence_flag * 100 + month, 'cnx_s': [pool, ...]},
# kept sorted ascending / descending by 'sequence' (filled by Init.init_cold_db_conn).
cold_db_conn_map_list = list()
cold_db_conn_map_list_reversed = list()
# domain -> {'db_group_s': [[pool, ...], ...], 'rw_mode': ..., 'domain': ...};
# group 0 is logged as the write group, the remaining groups as read groups.
db_conn_map = dict()
# One dict per cold db group: {sequence_flag: {month: [pool, ...]}}.
cold_db_conn_map = list()
# table name -> the db_conn_map entry of the domain that holds that table.
object_db_map = dict()
# 'rules': table -> column -> field rule from Rules.get_db_field_rule;
# 'index': table -> list of indexable field names / '__'-joined field combinations.
objects_model = {'rules': dict(), 'index': dict()}

# NOTE(review): presumably a timestamp taken at import time; confirm what
# ji.Common.ts() returns.
ts = ji.Common.ts()

# Connect to Redis without a password first. If PING raises a ResponseError
# (presumably the server requires authentication), reconnect with the
# configured password.
r = redis.StrictRedis(host=app.config['REDIS']['host'], port=app.config['REDIS']['port'],
                      db=app.config['REDIS']['db'], decode_responses=True)
try:
    r.ping()
except redis.exceptions.ResponseError:
    r = redis.StrictRedis(host=app.config['REDIS']['host'], port=app.config['REDIS']['port'],
                          db=app.config['REDIS']['db'], password=app.config['REDIS']['password'],
                          decode_responses=True)


class Init(object):
    """Startup routines that build the module-level connection and schema maps.

    Every method is static and works by filling the module globals
    ``db_conn_map``, ``origin_db_conn_map``, ``cold_db_conn_map``,
    ``cold_origin_db_conn_map``, ``cold_db_conn_map_list(_reversed)``,
    ``object_db_map`` and ``objects_model``.
    """

    @staticmethod
    def _create_pool(db_item):
        """Open a MySQL connection pool for one validated ``db_item`` config dict."""
        return mysql.connector.pooling.MySQLConnectionPool(
            host=db_item['host'],
            user=db_item['user'],
            password=db_item['password'],
            port=db_item['port'],
            database=db_item['database'],
            # raise_on_warnings=True
            pool_size=db_item['pool_size']
        )

    @staticmethod
    def _abort_on_db_error(err):
        """Print and log a fatal pool-creation error, then exit with ``err.errno``."""
        if err.errno == errorcode.ER_ACCESS_DENIED_ERROR:
            e_msg = 'Something is wrong with your user name or password'
        elif err.errno == errorcode.ER_BAD_DB_ERROR:
            e_msg = 'Database does not exist'
        else:
            e_msg = err.msg

        print(e_msg)
        logger.critical(e_msg)
        sys.exit(err.errno)

    @staticmethod
    def _index_fields_of(key_name, column_name):
        """Return the field names/combinations one ``SHOW INDEX`` row makes indexable.

        For the PRIMARY key that is just the row's column. Any other index
        name is split on '__', sorted, and every leading prefix of the sorted
        fields is returned joined by '__' (e.g. 'b__a' -> ['a', 'a__b']).
        """
        if key_name == 'PRIMARY':
            return [column_name]

        index_field_s = sorted(key_name.split('__'))
        return ['__'.join(index_field_s[:i + 1]) for i in range(len(index_field_s))]

    @staticmethod
    def init_db_conn():
        """Validate every enabled non-cold ``DB_S`` entry and open its pools.

        Fills ``db_conn_map[domain]`` with ``rw_mode``, ``domain`` and
        ``db_group_s`` (one list of pools per configured group; group 0 is
        logged as the write group, the others as read groups). Pools are
        shared through ``origin_db_conn_map`` keyed by ``host:port:database``.
        Exits the process if a pool cannot be created.
        """
        global origin_db_conn_map, db_conn_map
        for item in app.config['DB_S']:
            if item.get('disable'):
                continue

            args_rules = [
                Rules.DOMAIN.value,
                Rules.RW_MODE.value
            ]
            # .get(): a missing rw_mode is then reported by the validator
            # below instead of surfacing as a bare KeyError here.
            if item.get('rw_mode') == states.RWMode.BW_AR.value:
                # 2 to 4 groups: one write group plus at most 3 read groups.
                args_rules.append((list, 'db_group_s', (2, 4)))
            else:
                # Exactly one write group and one read group.
                args_rules.append((list, 'db_group_s', (2, 2)))

            ji.Check.previewing(args_rules, item)

            # Cold-storage mappings are handled by init_cold_db_conn.
            if item['domain'] == states.DBDomain.cold.value:
                continue

            if item['domain'] not in db_conn_map:
                db_conn_map[item['domain']] = {
                    'db_group_s': list(),
                    'rw_mode': item['rw_mode'],
                    'domain': item['domain']
                }

            domain_map = db_conn_map[item['domain']]
            for i, db_group in enumerate(item['db_group_s']):
                domain_map['db_group_s'].append(list())
                for db_item in db_group:
                    args_rules = [
                        Rules.DB_HOST.value,
                        Rules.DB_USER.value,
                        Rules.DB_PASSWORD.value,
                        Rules.DB_DATABASE.value,
                        Rules.DB_POOL_SIZE.value,
                        Rules.PORT.value
                    ]

                    if 'database' not in db_item:
                        db_item['database'] = app.config['DATABASE']

                    if 'port' not in db_item:
                        db_item['port'] = app.config['DB_PORT']
                    ji.Check.previewing(args_rules, db_item)

                    origin_db_conn_map_key = ':'.join([db_item['host'], str(db_item['port']), db_item['database']])
                    if origin_db_conn_map_key not in origin_db_conn_map:
                        try:
                            cnxpool = Init._create_pool(db_item)
                        except mysql.connector.Error as err:
                            Init._abort_on_db_error(err)
                        else:
                            origin_db_conn_map[origin_db_conn_map_key] = cnxpool

                    domain_map['db_group_s'][i].append(origin_db_conn_map[origin_db_conn_map_key])

                    logger.debug(u' '.join([states.DBDomain.get_label(item['domain']),
                                            states.RWMode.get_label(item['rw_mode']), u'读' if i > 0 else u'写',
                                            u'的链接已建立!']))

    @staticmethod
    def init_cold_db_conn():
        """Open pools for the enabled cold-storage ``DB_S`` entries.

        Each cold db_item is expanded into 12 monthly databases named
        ``<DATABASE>__<sequence_flag>__<month>``. A monthly database that does
        not exist is only warned about and skipped; any other pool error exits
        the process. Results go to ``cold_db_conn_map[group][sequence_flag][month]``.
        The read group (index 1) is then flattened into
        ``cold_db_conn_map_list`` (ascending by ``sequence_flag * 100 + month``)
        and ``cold_db_conn_map_list_reversed`` (descending).
        """
        global cold_db_conn_map, cold_db_conn_map_list, cold_db_conn_map_list_reversed
        for item in app.config['DB_S']:
            if item.get('disable'):
                continue

            # Only cold-storage entries are handled here.
            if item['domain'] != states.DBDomain.cold.value:
                continue

            args_rules = [
                Rules.DOMAIN.value,
                # One write group and one read group; cold storage only allows the SW_SR rw_mode.
                (list, 'db_group_s', (2, 2))
            ]
            ji.Check.previewing(args_rules, item)

            for i, db_group in enumerate(item['db_group_s']):
                cold_db_conn_map.append(dict())
                for db_item in db_group:
                    args_rules = [
                        Rules.DB_HOST.value,
                        Rules.DB_USER.value,
                        Rules.DB_PASSWORD.value,
                        Rules.DB_DATABASE.value,
                        Rules.DB_POOL_SIZE.value,
                        Rules.PORT.value,
                        Rules.DB_SEQUENCE_FLAG.value
                    ]

                    if 'port' not in db_item:
                        db_item['port'] = app.config['DB_PORT']

                    if db_item['sequence_flag'] not in cold_db_conn_map[i]:
                        cold_db_conn_map[i][db_item['sequence_flag']] = dict()

                    for month in range(1, 13):
                        db_item['database'] = '__'.join([app.config['DATABASE'], str(db_item['sequence_flag']),
                                                         str(month)])
                        ji.Check.previewing(args_rules, db_item)

                        origin_db_conn_map_key = ':'.join([db_item['host'], str(db_item['port']), db_item['database']])
                        if origin_db_conn_map_key not in cold_origin_db_conn_map:
                            try:
                                cnxpool = Init._create_pool(db_item)
                            except mysql.connector.Error as err:
                                if err.errno == errorcode.ER_BAD_DB_ERROR:
                                    # A missing monthly database is skipped, not fatal.
                                    logger.warning(''.join([u'数据库 ', origin_db_conn_map_key, u' 不存在']))
                                    continue
                                Init._abort_on_db_error(err)
                            else:
                                cold_origin_db_conn_map[origin_db_conn_map_key] = cnxpool

                        cold_db_conn_map[i][db_item['sequence_flag']].setdefault(month, list()).append(
                            cold_origin_db_conn_map[origin_db_conn_map_key])

                        logger.debug(u' '.join([states.DBDomain.get_label(item['domain']), db_item['database'],
                                                states.RWMode.get_label(item['rw_mode']), u'读' if i > 0 else u'写',
                                                u'的链接已建立!']))

        if len(cold_db_conn_map) >= 2:
            # Flatten the read group (index 1) into sortable (sequence, pools) items.
            for k, v in cold_db_conn_map[1].items():
                for k_s, v_s in v.items():
                    item = {'sequence': k * 100 + k_s, 'cnx_s': v_s}
                    cold_db_conn_map_list.append(item)
                    cold_db_conn_map_list_reversed.append(item)

            cold_db_conn_map_list.sort(key=lambda _item: _item['sequence'])
            cold_db_conn_map_list_reversed.sort(key=lambda _item: _item['sequence'], reverse=True)

    @staticmethod
    def init_object_db_map():
        """Map every table name to the ``db_conn_map`` entry of the domain holding it.

        Runs ``SHOW TABLES`` on the first pool of each domain's group 0. When
        several domains contain a table with the same name, the first domain
        iterated keeps it.
        """
        global object_db_map
        sql_stmt = 'SHOW TABLES;'
        for domain_key, domain in db_conn_map.items():
            cnx = domain['db_group_s'][0][0].get_connection()
            try:
                cursor = cnx.cursor(dictionary=False, buffered=False)
                cursor.execute(sql_stmt)
                row = cursor.fetchone()
                while row is not None and len(row) == 1:
                    # Do not let a later table with the same name overwrite the first mapping.
                    if row[0] not in object_db_map:
                        object_db_map[row[0]] = db_conn_map[domain_key]
                        logger.debug(u''.join([row[0], u' 在', states.DBDomain.get_label(domain_key), u'中, 拥有',
                                               states.RWMode.get_label(domain['rw_mode']), u'!']))
                    row = cursor.fetchone()
            finally:
                # Always hand the pooled connection back, even if the query failed.
                cnx.close()

    @staticmethod
    def init_objects_model():
        """Load per-table column rules and indexable fields into ``objects_model``.

        ``objects_model['rules'][table][column]`` receives
        ``Rules.get_db_field_rule(...)`` built from information_schema.COLUMNS
        (schemas whose name starts with ``app.config['DATABASE']``; the first
        definition seen for a column wins). ``objects_model['index'][table]``
        receives the unique field names/combinations derived from
        ``SHOW INDEX`` (see ``_index_fields_of``). Exits the process when a
        hot-domain table lacks ``app.config['TIME_LINE_FIELD']``.
        """
        global objects_model
        for _object, domain in object_db_map.items():
            cnx = domain['db_group_s'][0][0].get_connection()
            try:
                cursor = cnx.cursor(dictionary=True, buffered=False)

                # TABLE_SCHEMA (schema), TABLE_NAME (table), COLUMN_NAME (column), COLUMN_DEFAULT (default),
                # IS_NULLABLE, DATA_TYPE (type), COLUMN_TYPE (full type description),
                # CHARACTER_MAXIMUM_LENGTH (max string length), NUMERIC_PRECISION (numeric precision).
                # Values are passed as query parameters rather than spliced into the SQL text.
                sql_stmt = ('SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE,'
                            'NUMERIC_PRECISION, CHARACTER_MAXIMUM_LENGTH, COLUMN_TYPE, COLUMN_DEFAULT '
                            'FROM information_schema.COLUMNS '
                            'WHERE TABLE_SCHEMA LIKE %s AND TABLE_NAME = %s;')
                cursor.execute(sql_stmt, (app.config['DATABASE'] + '%', _object))
                row = cursor.fetchone()
                while row is not None:
                    table_rules = objects_model['rules'].setdefault(row['TABLE_NAME'], dict())

                    # Keep the first definition seen for a column (several schemas may match the LIKE).
                    if row['COLUMN_NAME'] not in table_rules:
                        table_rules[row['COLUMN_NAME']] = Rules.get_db_field_rule(
                            field=row['COLUMN_NAME'], data_type=row['DATA_TYPE'].upper(),
                            column_type=row['COLUMN_TYPE'],
                            char_length=row['CHARACTER_MAXIMUM_LENGTH'] if row['CHARACTER_MAXIMUM_LENGTH'] is not None
                            else 0,
                            is_nullable=False if row['IS_NULLABLE'] == 'NO' else True
                        )

                    row = cursor.fetchone()

                if object_db_map[_object]['domain'] == states.DBDomain.hot.value:
                    # .get(): a table with no matched columns must reach this check, not crash with KeyError.
                    if app.config['TIME_LINE_FIELD'] not in objects_model['rules'].get(_object, dict()):
                        logger.critical(''.join(['时间维度字段 ', app.config['TIME_LINE_FIELD'], ' 必须出现在 ',
                                                 str(_object), '表中']))
                        sys.exit(-1)

                # Back-quote the identifier so reserved words / special characters are safe.
                sql_stmt = ''.join(['SHOW INDEX FROM `', _object.replace('`', '``'), '`;'])
                cursor.execute(sql_stmt)
                row = cursor.fetchone()
                while row is not None:
                    index_list = objects_model['index'].setdefault(row['Table'], list())
                    # Multi-column indexes yield one row per column, and different indexes can
                    # share prefixes, so de-duplicate every derived entry.
                    for index_field in Init._index_fields_of(row['Key_name'], row['Column_name']):
                        if index_field not in index_list:
                            index_list.append(index_field)

                    row = cursor.fetchone()
            finally:
                # Always hand the pooled connection back, even if a query failed.
                cnx.close()

    @staticmethod
    def db_keepalived():
        """Ping every pool every 5 seconds, forever.

        Covers all pools in ``db_conn_map`` and the cold read pools in
        ``cold_db_conn_map_list``. Failures are logged, never raised.
        """
        def ping(label='', _cnxpool=None):
            """Borrow one connection from ``_cnxpool``, ping it, and always return it."""
            if _cnxpool is None:
                logger.critical(''.join(['cnxpool must not None by ', label]))
                return

            _cnx = None
            try:
                _cnx = _cnxpool.get_connection()
                _cnx.ping(attempts=1, delay=0)
            except mysql.connector.errors.InterfaceError as err:
                logger.critical(err.msg)
            except mysql.connector.Error as err:
                logger.error(err)
            finally:
                # Closing a pooled connection returns it to its pool. Without this on the
                # failure path, every failed ping permanently removed a connection from the pool.
                if _cnx is not None:
                    try:
                        _cnx.close()
                    except mysql.connector.Error as err:
                        logger.error(err)

        while True:
            time.sleep(5)
            for domain_key, domain in db_conn_map.items():
                for i, db_group in enumerate(domain['db_group_s']):
                    for k, cnxpool in enumerate(db_group):
                        ping(label=' '.join([states.DBDomain.get_label(domain=domain_key), str(i), str(k)]),
                             _cnxpool=cnxpool)

            for instance in cold_db_conn_map_list:
                for i, cnxpool in enumerate(instance['cnx_s']):
                    ping(label=' '.join(['冷库', str(instance['sequence']), str(i)]), _cnxpool=cnxpool)
